
import logging
from init import PATHS
LOGGER = logging.getLogger(__name__)
import os
import pandas as pd
import numpy as np
from C_PatentVariables import collecting_patstat, family_aggregation, firm_year_aggregation


def get_applnid_with_energy_codes(df_greendirty_codes, test=False):
    """Return the unique PATSTAT appln_ids tagged with any code in ``df_greendirty_codes``.

    Three lookups are run, in this order, and then merged:
      1. CPC symbols matched against the 'CPCfull' column (subgroup level);
      2. IPC symbols matched against the 'IPCfull' column (subgroup level);
      3. IPC symbols matched against 'IPCsubclass', restricted to rows whose
         'Code' is itself a subclass- or class-level code (i.e. equals
         'IPCsubclass' or 'IPCclass').

    The third lookup is needed because the PATSTAT IPC table also lists some
    codes at the subclass level (4-character codes such as "C10J"); according
    to the original author's check, all other CPC/IPC codes in PATSTAT are at
    subgroup level.

    Args:
        df_greendirty_codes: code table with at least the columns 'CPCfull',
            'IPCfull', 'IPCsubclass', 'IPCclass' and 'Code'.
        test: forwarded to the collector (restricts how much PATSTAT is read).

    Returns:
        list of distinct appln_id values.
    """
    def distinct(series):
        # De-duplicate while keeping first-appearance order.
        return series.drop_duplicates().tolist()

    cpc_matches = collecting_appln_id_for_subgroup_level_codes(
        df_greendirty_codes, 'cpc', distinct(df_greendirty_codes['CPCfull']), test)
    ipc_matches = collecting_appln_id_for_subgroup_level_codes(
        df_greendirty_codes, 'ipc', distinct(df_greendirty_codes['IPCfull']), test)

    code = df_greendirty_codes['Code']
    is_coarse_code = (code == df_greendirty_codes['IPCsubclass']) | (code == df_greendirty_codes['IPCclass'])
    coarse_codes = distinct(df_greendirty_codes.loc[is_coarse_code, 'IPCsubclass'])
    ipc_subclass_matches = collecting_appln_id_for_subgroup_level_codes(
        df_greendirty_codes, 'ipc', coarse_codes, test, level='subclass')

    merged = merging_appln_id_from_cpc_and_ipc_codes(cpc_matches, ipc_matches, ipc_subclass_matches)
    return distinct(merged['appln_id'])


def collecting_appln_id_for_subgroup_level_codes(df_greendirty_codes, scheme, list_subgroupcodes, test, level=''):
    """Collect PATSTAT application ids whose CPC/IPC symbols match the given codes.

    Every file in ``PATHS.patstatglobal`` whose name contains the PATSTAT table
    number for ``scheme`` ('224' for CPC, '209' for IPC) is read in chunks.
    Spaces are stripped from the class symbols, rows whose symbol is in
    ``list_subgroupcodes`` are kept, and Sector / Sub-sector / Type / Code /
    Include are attached from ``df_greendirty_codes``.

    Args:
        df_greendirty_codes: code table. With ``level == ''`` it must contain
            '<SCHEME>full'; otherwise 'Code', 'IPCsubclass' and 'IPCclass'.
            It must also contain 'Sector', 'Sub-sector', 'Type', 'Code' and
            'Include' in both cases.
        scheme: 'cpc' or 'ipc'.
        list_subgroupcodes: codes (without spaces) to match against PATSTAT
            symbols. Missing values (NaN) are ignored.
        test: if True, only the first chunk of each matching file is read.
        level: '' matches via '<SCHEME>full' (left merge). Any other value
            matches subclass/class-level codes via 'IPCsubclass' (inner merge).

    Returns:
        DataFrame with one row per appln_id. Every other column holds the
        comma-joined unique values found for that appln_id.

    Raises:
        ValueError: if no file for the scheme's PATSTAT table is found.
    """
    LOGGER.info('       Collecting appln_id of patent with energy code {}{}.csv'.format(scheme.upper(), level))
    name_PATSTAT_table = {'cpc': '224', 'ipc': '209'}
    symbol_col = '{}_class_symbol'.format(scheme)
    out_col = scheme.upper()
    # pandas' isin treats NaN as equal to NaN. A NaN in the code list (e.g. a row
    # with no CPC/IPC equivalent) would therefore capture PATSTAT rows with a
    # missing symbol, so NaN entries are dropped here.
    codes = {c for c in list_subgroupcodes if pd.notna(c)}
    # Use table tls224 (CPC) / tls209 (IPC) in PATSTAT.
    # os.listdir order is arbitrary. Sorting makes the concatenation order, and
    # therefore the order of the joined values below, reproducible.
    list_files = sorted(f for f in os.listdir(PATHS.patstatglobal) if name_PATSTAT_table[scheme] in f)
    if not list_files:
        raise ValueError('No PATSTAT file containing "{}" found in {}'.format(
            name_PATSTAT_table[scheme], PATHS.patstatglobal))
    list_col = ['appln_id', symbol_col]
    chunks = []
    for file in list_files:
        LOGGER.info('       File: {}'.format(file))
        for dfchunk in pd.read_csv(PATHS.patstatglobal / file, chunksize=500000, usecols=list_col, sep=','):
            # Remove spaces from the PATSTAT symbols so they are comparable to our codes
            # (a plain literal replace, no regex needed).
            dfchunk[symbol_col] = dfchunk[symbol_col].str.replace(' ', '', regex=False)
            # Keep only the codes we are interested in.
            matched = dfchunk[dfchunk[symbol_col].isin(codes)].rename(columns={symbol_col: out_col})
            chunks.append(matched)
            if test:
                break
    df_EnergyPatents = pd.concat(chunks)
    # Add the sector / type information.
    info_cols = ['Sector', 'Sub-sector', 'Type', 'Code', 'Include']
    if level == '':
        full_col = '{}full'.format(out_col)
        df_EnergyPatents = pd.merge(df_EnergyPatents, df_greendirty_codes[[full_col] + info_cols],
                                    left_on=out_col, right_on=full_col, how='left')
        df_EnergyPatents = df_EnergyPatents.drop(columns=[full_col])
    else:
        # Rows whose code is itself at subclass or class level.
        mask1 = df_greendirty_codes['Code'] == df_greendirty_codes['IPCsubclass']
        mask2 = df_greendirty_codes['Code'] == df_greendirty_codes['IPCclass']
        df_EnergyPatents = pd.merge(df_EnergyPatents, df_greendirty_codes[mask1 | mask2][['IPCsubclass'] + info_cols],
                                    left_on=out_col, right_on='IPCsubclass', how='inner')
        df_EnergyPatents = df_EnergyPatents.drop(columns=['IPCsubclass'])
    # Collapse to one row per application. Each column becomes the comma-joined
    # unique values (as strings) in order of first appearance.
    df_EnergyPatents = df_EnergyPatents.groupby(['appln_id'], as_index=False).agg(
        lambda x: ','.join([str(i) for i in x.unique()]))
    return df_EnergyPatents


def merging_appln_id_from_cpc_and_ipc_codes(df_EnergyPatentsCPC, df_EnergyPatentsIPC, df_EnergyPatentsIPCsubclass):
    """Combine CPC-, IPC-subgroup- and IPC-subclass-level matches into one table.

    The two IPC tables are stacked and re-collapsed to one row per appln_id.
    Their Sector / Sub-sector / Type / Code / Include columns are prefixed with
    'IPC_' and the CPC ones with 'CPC_'. The CPC and IPC tables are then
    outer-joined on appln_id, so an id found by only one scheme has NaN in the
    other scheme's columns.

    Args:
        df_EnergyPatentsCPC: per-appln_id CPC matches (comma-joined string columns).
        df_EnergyPatentsIPC: per-appln_id IPC subgroup-level matches.
        df_EnergyPatentsIPCsubclass: per-appln_id IPC subclass-level matches.

    Returns:
        DataFrame with one row per appln_id.
    """
    def _join_unique_tokens(values):
        # Inputs are already comma-joined strings, so split them before
        # de-duplicating. Otherwise 'A' and 'A,B' would become 'A,A,B'
        # instead of 'A,B'. Order of first appearance is kept.
        # NOTE(review): assumes individual labels contain no commas.
        tokens = (token for value in values for token in str(value).split(','))
        return ','.join(dict.fromkeys(tokens))

    info_cols = ('Sector', 'Sub-sector', 'Type', 'Code', 'Include')
    # Combine IPC.
    df_EnergyPatentsIPC = pd.concat([df_EnergyPatentsIPC, df_EnergyPatentsIPCsubclass]).drop_duplicates()
    df_EnergyPatentsIPC = df_EnergyPatentsIPC.groupby(['appln_id'], as_index=False).agg(_join_unique_tokens)
    df_EnergyPatentsIPC = df_EnergyPatentsIPC.rename(columns={c: 'IPC_' + c for c in info_cols})
    df_EnergyPatentsCPC = df_EnergyPatentsCPC.rename(columns={c: 'CPC_' + c for c in info_cols})
    # Combine IPC and CPC.
    df_EnergyPatents = df_EnergyPatentsCPC.merge(df_EnergyPatentsIPC, on='appln_id', how='outer')
    # Log the number of distinct ids (the old .shape logged a tuple such as "(123,)").
    LOGGER.info('       Number of patstat id connected to energy IPC or CPC: {}'.format(df_EnergyPatents['appln_id'].nunique()))
    return df_EnergyPatents


def main(test=False):
    """Build the PATSTAT dataset for patents with transportation CPC/IPC codes.

    Steps:
      1. Load the CPC/IPC code table and keep only rows with Sector == 'Transport'.
      2. Find the matching appln_ids.
      3. Collect the corresponding docdb ids.
      4. Run the application-level collection, family aggregation, CPC/IPC
         extraction, NACE extraction and psn_sector extraction, passing the
         output name 'ipc_cpc_transpo' to each.

    Args:
        test: forwarded to every step that accepts it.
    """
    LOGGER.info('SCRIPT: d_patent_ipccpc_transportation.py')
    LOGGER.info("       Collect info for patents that fall into cpc and ipc codes belonging to transportation")
    # Import the CPC/IPC codes we are interested in.
    codes_path = PATHS.dropbox / 'Data_outputted/C_PatentVariables/CPC_IPC_codes_allsubgroups.csv'
    all_codes = pd.read_csv(codes_path)
    # Only the Transport sector. Electricity Transport is not needed because its
    # upper-level code is already included in Transport (F02B).
    transport_codes = all_codes.loc[all_codes['Sector'] == 'Transport']
    candidate_ids = get_applnid_with_energy_codes(transport_codes, test)
    # First pass through PATSTAT to collect all docdb ids matching these appln_ids.
    appln_ids, docdb_ids = collecting_patstat.collecting_ids(candidate_ids, [], test)
    namefile = 'ipc_cpc_transpo'

    # (log message, action) pairs executed in order.
    pipeline = (
        ("       Collect info at application level",
         lambda: collecting_patstat.collecting_info_on_applnids(appln_ids, docdb_ids, namefile, test, small=True)),
        ("       Aggregating patents at family level",
         lambda: family_aggregation.aggregate_at_familylevel(namefile)),
        ("       Collect all ipc and cpc codes of patents and aggregate them at family level",
         lambda: collecting_patstat.extract_and_save_cpc_ipc_codes(namefile, test)),
        ("       Collect naces codes associated to patents and aggregate them at family level",
         lambda: collecting_patstat.extract_and_save_nace_codes(namefile, test)),
        ("       Collect psn_sector of applicants associated to patents and aggregate them at family level",
         lambda: collecting_patstat.extract_and_save_psnsector(appln_ids, namefile, test)),
    )
    for message, run_step in pipeline:
        LOGGER.info(message)
        run_step()
    LOGGER.info('SCRIPT END: d_patent_ipccpc_transportation.py')
